{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71a329f0",
   "metadata": {},
   "source": [
    "# Customized preprocessing and postprocessing (in TRTLLM)\n",
    "In this tutorial, you will use LMI container from DLC to SageMaker and run inference with it.\n",
    "\n",
    "Please make sure the following permission granted before running the notebook:\n",
    "\n",
    "- S3 bucket push access\n",
    "- SageMaker access\n",
    "\n",
    "## Step 1: Let's bump up SageMaker and import stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fa3208",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9ac353",
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import Model, image_uris, serializers, deserializers\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "region = sess._region_name  # region name of the current SageMaker Studio environment\n",
    "account_id = sess.account_id()  # account_id of the current SageMaker Studio environment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81deac79",
   "metadata": {},
   "source": [
    "## Step 2: Start preparing model artifacts\n",
    "In LMI contianer, we expect some artifacts to help setting up the model\n",
    "- serving.properties (required): Defines the model server settings\n",
    "- model.py (optional): A python file to define the core inference logic\n",
    "- requirements.txt (optional): Any additional pip wheel need to install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b011bf5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile serving.properties\n",
    "engine=MPI\n",
    "option.model_id=TheBloke/Llama-2-7B-fp16\n",
    "option.tensor_parallel_degree=4\n",
    "option.max_rolling_batch_size=16\n",
    "option.trust_remote_code=true"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a0ab053",
   "metadata": {},
   "source": [
    "In this step, we will try to override the [default TensorRT-LLM handler](https://github.com/deepjavalibrary/djl-serving/blob/0.27.0-dlc/engines/python/setup/djl_python/tensorrt_llm.py) provided by DJLServing. We will replace the output formatter with `custom_output_formatter`, which outputs the token id, text, and log probability instead of just the text. \n",
    "\n",
    "We will also replace the input formatter with `custom_input_formatter` to accept \"prompt\" instead of \"inputs\" in the request (e.g. `{\"prompt\": \"...\", \"parameters\": {}}` is now a valid request instead of `{\"inputs\": \"...\", \"parameters\": {}}`\n",
    "\n",
    "You can replace either of these functions with your own custom input formatter and output formatter. The only restrictions are as follows:\n",
    "\n",
    "Input Formatter\n",
    "- Returns a 5-tuple of the following:\n",
    "  - A list of strings (prompt)\n",
    "  - An int (size of input)\n",
    "  - A dictionary (containing settings like top_k, temperature, etc.)\n",
    "  - A dictionary (for error logging)\n",
    "  - A list of Input objects (just use `inputs.get_batches()`)\n",
    " \n",
    "Output Formatter\n",
    "- 5 required parameters (these will be sent into the output formatter by the service):\n",
    "  - a Token object (defined [here](https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/rolling_batch/rolling_batch.py))\n",
    "  - a boolean denoting if this is the first token\n",
    "  - a boolean denoting if this is the last token\n",
    "  - a dictionary with miscellaneous information (e.g. finish reason)\n",
    "  - a string containing previously generated tokens\n",
    "- Returns a string\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d6798b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile model.py\n",
    "from djl_python.tensorrt_llm import TRTLLMService\n",
    "from djl_python.inputs import Input\n",
    "from djl_python.outputs import Output\n",
    "from djl_python.encode_decode import encode, decode\n",
    "import logging\n",
    "import json\n",
    "import types\n",
    "\n",
    "_service = TRTLLMService()\n",
    "\n",
    "def custom_output_formatter(token, first_token, last_token, details, generated_tokens):\n",
    "    \"\"\"\n",
    "    Replace this function with your custom output formatter.\n",
    "    \n",
    "    Args:\n",
    "        token (Token): Token object \n",
    "        first (bool): If first token \n",
    "        last (bool): If last token\n",
    "        aux (dict): Miscellaneous information\n",
    "        prev_response (str): Previously generated tokens\n",
    "\n",
    "    Returns:\n",
    "        (str): Response string\n",
    "        \n",
    "    \"\"\"\n",
    "    result = {\"token_id\": token.id, \"token_text\": token.text, \"token_log_prob\": token.log_prob}\n",
    "    if last_token:\n",
    "        result[\"finish_reason\"] = details[\"finish_reason\"]\n",
    "    return json.dumps(result) + \"\\n\"\n",
    "\n",
    "def custom_input_formatter(self, inputs):\n",
    "    \"\"\"\n",
    "    Replace this function with your custom input formatter.\n",
    "        \n",
    "    Args:\n",
    "        data (obj): The request data, dict or string  \n",
    "\n",
    "    Returns:\n",
    "        (tuple): input_data (list), input_size (list), parameters (dict), errors (dict), batch (list)\n",
    "    \"\"\"\n",
    "    input_data = []\n",
    "    input_size = []\n",
    "    parameters = []\n",
    "    errors = {}\n",
    "    batch = inputs.get_batches()\n",
    "    for i, item in enumerate(batch):\n",
    "        try:\n",
    "            content_type = item.get_property(\"Content-Type\")\n",
    "            input_map = decode(item, content_type)\n",
    "        except Exception as e:  # pylint: disable=broad-except\n",
    "            logging.warning(f\"Parse input failed: {i}\")\n",
    "            input_size.append(0)\n",
    "            errors[i] = str(e)\n",
    "            continue\n",
    "\n",
    "        _inputs = input_map.pop(\"prompt\", input_map)\n",
    "        if not isinstance(_inputs, list):\n",
    "            _inputs = [_inputs]\n",
    "        input_data.extend(_inputs)\n",
    "        input_size.append(len(_inputs))\n",
    "\n",
    "        _param = input_map.pop(\"parameters\", {})\n",
    "        if \"cached_prompt\" in input_map:\n",
    "            _param[\"cached_prompt\"] = input_map.pop(\"cached_prompt\")\n",
    "        if not \"seed\" in _param:\n",
    "            # set server provided seed if seed is not part of request\n",
    "            if item.contains_key(\"seed\"):\n",
    "                _param[\"seed\"] = item.get_as_string(key=\"seed\")\n",
    "        for _ in range(input_size[i]):\n",
    "            parameters.append(_param)\n",
    "\n",
    "    return input_data, input_size, parameters, errors, batch\n",
    "\n",
    "def handle(inputs: Input):\n",
    "    \"\"\"\n",
    "    Default handler function\n",
    "    \"\"\"\n",
    "    if not _service.initialized:\n",
    "        # stateful model\n",
    "        props = inputs.get_properties()\n",
    "        props['output_formatter'] = custom_output_formatter\n",
    "        _service.initialize(props)\n",
    "        _service.parse_input = types.MethodType(custom_input_formatter, _service)\n",
    "\n",
    "    if inputs.is_empty():\n",
    "        # initialization request\n",
    "        return None\n",
    "\n",
    "    return _service.inference(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0142973",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "mkdir mymodel\n",
    "mv serving.properties mymodel/\n",
    "mv model.py mymodel/\n",
    "tar czvf mymodel.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e58cf33",
   "metadata": {},
   "source": [
    "## Step 3: Start building SageMaker endpoint\n",
    "In this step, we will build SageMaker endpoint from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d955679",
   "metadata": {},
   "source": [
    "### Getting the container image URI\n",
    "\n",
    "[Large Model Inference available DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a174b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_uri = image_uris.retrieve(\n",
    "        framework=\"djl-tensorrtllm\",\n",
    "        region=sess.boto_session.region_name,\n",
    "        version=\"0.27.0\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11601839",
   "metadata": {},
   "source": [
    "### Upload artifact on S3 and create SageMaker model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b1e5ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_code_prefix = \"large-model-lmi/code\"\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "code_artifact = sess.upload_data(\"mymodel.tar.gz\", bucket, s3_code_prefix)\n",
    "print(f\"S3 Code or Model tar ball uploaded to --- > {code_artifact}\")\n",
    "\n",
    "model = Model(image_uri=image_uri, model_data=code_artifact, role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "004f39f6",
   "metadata": {},
   "source": [
    "### 4.2 Create SageMaker endpoint\n",
    "\n",
    "You need to specify the instance to use and endpoint names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e0e61cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_type = \"ml.g5.12xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model\")\n",
    "\n",
    "model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             endpoint_name=endpoint_name,\n",
    "             # container_startup_health_check_timeout=3600\n",
    "            )\n",
    "\n",
    "# our requests and responses will be in json format so we specify the serializer and the deserializer\n",
    "predictor = sagemaker.Predictor(\n",
    "    endpoint_name=endpoint_name,\n",
    "    sagemaker_session=sess,\n",
    "    serializer=serializers.JSONSerializer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb63ee65",
   "metadata": {},
   "source": [
    "## Step 5: Test and benchmark the inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24d6bf6d",
   "metadata": {},
   "source": [
    "Since we've changed the input preprocessing, the following will no longer work since the \"inputs\" field is no longer recognized:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9789399",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"inputs\": \"Large model inference is\", \"parameters\": {}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86ef6c50-d3b7-4d6b-97a9-d4da38b41d73",
   "metadata": {},
   "source": [
    "But this will work:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41bde6fc-74f1-4a1c-b467-49b508fa4f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"prompt\": \"Large model inference is\", \"parameters\": {}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ad3b98c-9b32-4549-8673-8aefdb2e6a17",
   "metadata": {},
   "source": [
    "Notice that the output format looks different compared to the output format in [an example without customized postprocessing](https://github.com/deepjavalibrary/djl-demo/blob/master/aws/sagemaker/large-model-inference/sample-llm/trtllm_rollingbatch_deploy_llama_13b.ipynb) because we changed the output formatter."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1cd9042",
   "metadata": {},
   "source": [
    "## Clean up the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d674b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "sess.delete_endpoint(endpoint_name)\n",
    "sess.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:amazon1] *",
   "language": "python",
   "name": "conda-env-amazon1-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
